import os
import korvus
import pytest
from multiprocessing import Pool
from typing import List, Dict, Any
import asyncio

####################################################################################
####################################################################################
## PLEASE BE AWARE THESE TESTS DO INVOLVE CHECKS ON LAZILY CREATED DATABASE ITEMS ##
## IF ANY OF THE COLLECTION NAMES ALREADY EXIST, SOME TESTS MAY FAIL              ##
## THIS DOES NOT MEAN THE SDK IS BROKEN. PLEASE CLEAR YOUR DATABASE INSTANCE      ##
## BEFORE RUNNING ANY TESTS                                                       ##
####################################################################################
####################################################################################

# Set up korvus logging once, at import time, before any test in this module runs.
korvus.init_logger()


def generate_dummy_documents(count: int) -> List[Dict[str, Any]]:
    """Build ``count`` synthetic documents with ids ``0 .. count - 1``.

    Each document carries string fields derived from its index (``title``,
    ``body``, ``text``, ``name``), a constant ``project``, two numeric fields
    scaled from the index (``floating_uuid`` and ``uuid``) and a ``test`` key
    whose value is ``None``.
    """
    return [
        {
            "id": index,
            "title": f"Test Document {index}",
            "body": f"Test body {index}",
            "text": f"This is a test document: {index}",
            "project": "a10",
            "floating_uuid": index * 1.01,
            "uuid": index * 10,
            "test": None,
            "name": f"Test Document {index}",
        }
        for index in range(count)
    ]


###################################################
## Test the API exposed is correct ################
###################################################


def test_can_create_collection():
    """Constructing a Collection from a name yields an object."""
    created = korvus.Collection(name="test_p_c_tscc_0")
    assert created is not None


def test_can_create_model():
    """Constructing a Model with no arguments yields an object."""
    default_model = korvus.Model()
    assert default_model is not None


def test_can_create_splitter():
    """Constructing a Splitter with no arguments yields an object."""
    default_splitter = korvus.Splitter()
    assert default_splitter is not None


def test_can_create_pipeline():
    """Constructing a Pipeline from a name and an empty dict yields an object."""
    empty_pipeline = korvus.Pipeline("test_p_p_tccp_0", {})
    assert empty_pipeline is not None


def test_can_create_single_field_pipeline():
    """A SingleFieldPipeline can be built from a name, a Model, a Splitter and an empty dict."""
    # Arguments are evaluated left to right, so the Model is still created
    # before the Splitter.
    single_field = korvus.SingleFieldPipeline(
        "test_p_p_tccsfp_0", korvus.Model(), korvus.Splitter(), {}
    )
    assert single_field is not None


def test_can_create_builtins():
    """Constructing Builtins with no arguments yields an object."""
    instance = korvus.Builtins()
    assert instance is not None


@pytest.mark.asyncio
async def test_can_embed_with_builtins():
    """``Builtins.embed`` returns something for a single input string."""
    embedding = await korvus.Builtins().embed("intfloat/e5-small-v2", "test")
    assert embedding is not None


@pytest.mark.asyncio
async def test_can_embed_batch_with_builtins():
    """``Builtins.embed_batch`` returns something for a one-element batch."""
    inputs = ["test"]
    embeddings = await korvus.Builtins().embed_batch("intfloat/e5-small-v2", inputs)
    assert embeddings is not None


###################################################
## Test searches ##################################
###################################################


@pytest.mark.asyncio
async def test_can_search():
    """Run ``Collection.search`` over ``title`` and ``body`` and check result order.

    A pipeline covering the ``title`` and ``body`` fields is added to a fresh
    collection, five dummy documents are upserted, and a combined
    full-text / semantic query with an ``id`` filter is issued. The ``id``
    values of the returned results must come back as ``[3, 5, 4]``. The
    collection is archived at the end.
    """
    title_config = {
        "semantic_search": {
            "model": "intfloat/e5-small-v2",
            "parameters": {"prompt": "passage: "},
        }
    }
    body_config = {
        "splitter": {"model": "recursive_character"},
        "semantic_search": {
            "model": "text-embedding-ada-002",
            "source": "openai",
        },
        "full_text_search": {"configuration": "english"},
    }
    pipeline = korvus.Pipeline(
        "test_p_p_tcs_0", {"title": title_config, "body": body_config}
    )

    collection = korvus.Collection("test_p_c_tsc_13")
    await collection.add_pipeline(pipeline)
    await collection.upsert_documents(generate_dummy_documents(5))

    search_request = {
        "query": {
            "full_text_search": {"body": {"query": "Test", "boost": 1.2}},
            "semantic_search": {
                "title": {
                    "query": "This is a test",
                    "parameters": {"prompt": "passage: "},
                    "boost": 2.0,
                },
                "body": {"query": "This is the body test", "boost": 1.01},
            },
            "filter": {"id": {"$gt": 1}},
        },
        "limit": 5,
    }
    response = await collection.search(search_request, pipeline)

    returned_ids = [hit["id"] for hit in response["results"]]
    assert returned_ids == [3, 5, 4]
    await collection.archive()


###################################################
## Test various vector searches ###################
###################################################


@pytest.mark.asyncio
async def test_can_vector_search():
    """Run ``Collection.vector_search`` over ``title`` and ``text`` with a filter.

    A pipeline covering the ``title`` and ``text`` fields is added to a fresh
    collection and five dummy documents are upserted. The search queries both
    fields with the filter ``{"id": {"$gt": 2}}`` and the document ids of the
    results must be ``[3, 3, 4, 4]``. The collection is archived at the end.
    """
    schema = {
        "title": {
            "semantic_search": {
                "model": "intfloat/e5-small-v2",
                "parameters": {"prompt": "passage: "},
            },
            "full_text_search": {"configuration": "english"},
        },
        "text": {
            "splitter": {"model": "recursive_character"},
            "semantic_search": {
                "model": "intfloat/e5-small-v2",
                "parameters": {"prompt": "passage: "},
            },
        },
    }
    search_request = {
        "query": {
            "fields": {
                "title": {
                    "query": "Test document: 2",
                    "parameters": {"prompt": "passage: "},
                    "full_text_filter": "test",
                },
                "text": {
                    "query": "Test document: 2",
                    "parameters": {"prompt": "passage: "},
                },
            },
            "filter": {"id": {"$gt": 2}},
        },
        "limit": 5,
    }

    pipeline = korvus.Pipeline("test_p_p_tcvs_0", schema)
    collection = korvus.Collection("test_p_c_tcvs_3")
    await collection.add_pipeline(pipeline)
    await collection.upsert_documents(generate_dummy_documents(5))

    matches = await collection.vector_search(search_request, pipeline)
    found_ids = [match["document"]["id"] for match in matches]
    assert found_ids == [3, 3, 4, 4]
    await collection.archive()


@pytest.mark.asyncio
async def test_can_vector_search_with_query_builder():
    """Run a vector recall through the ``Collection.query()`` builder.

    Three dummy documents are upserted before the single-field pipeline is
    added. The builder is then limited to 10 rows and fetched; the ``id`` of
    the third element of each returned tuple must be ``[1, 2, 0]`` in order.
    The collection is archived at the end.
    """
    embedding_model = korvus.Model(
        "intfloat/e5-small-v2", "korvus", {"prompt": "passage: "}
    )
    text_splitter = korvus.Splitter()
    pipeline = korvus.SingleFieldPipeline(
        "test_p_p_tcvswqb_1", embedding_model, text_splitter
    )
    collection = korvus.Collection(name="test_p_c_tcvswqb_5")

    # Documents go in first; the pipeline is attached afterwards.
    await collection.upsert_documents(generate_dummy_documents(3))
    await collection.add_pipeline(pipeline)

    builder = collection.query().vector_recall("Here is some query", pipeline)
    rows = await builder.limit(10).fetch_all()

    recalled_ids = [doc["id"] for _first, _second, doc in rows]
    assert recalled_ids == [1, 2, 0]
    await collection.archive()


@pytest.mark.asyncio
async def test_can_vector_search_with_rerank():
    """Run ``Collection.vector_search`` with a ``rerank`` section in the request.

    A pipeline covering the ``title`` and ``text`` fields is added to a fresh
    collection and five dummy documents are upserted. The search queries both
    fields with the filter ``{"id": {"$gt": 2}}`` and asks for reranking; the
    document ids of the results must be ``[3, 3, 4, 4]``. The collection is
    archived at the end.
    """
    pipeline = korvus.Pipeline(
        "test_p_p_tcvswr_0",
        {
            "title": {
                "semantic_search": {
                    "model": "intfloat/e5-small-v2",
                    "parameters": {"prompt": "passage: "},
                },
                "full_text_search": {"configuration": "english"},
            },
            "text": {
                "splitter": {"model": "recursive_character"},
                "semantic_search": {
                    "model": "intfloat/e5-small-v2",
                    "parameters": {"prompt": "passage: "},
                },
            },
        },
    )
    # Use a collection name of this test's own. It previously reused
    # "test_p_c_tcvs_3" from test_can_vector_search, so leftover or concurrent
    # state from that test could leak into this one.
    collection = korvus.Collection("test_p_c_tcvswr_0")
    await collection.add_pipeline(pipeline)
    await collection.upsert_documents(generate_dummy_documents(5))
    results = await collection.vector_search(
        {
            "query": {
                "fields": {
                    "title": {
                        "query": "Test document: 2",
                        "parameters": {"prompt": "passage: "},
                        "full_text_filter": "test",
                    },
                    "text": {
                        "query": "Test document: 2",
                        "parameters": {"prompt": "passage: "},
                    },
                },
                "filter": {"id": {"$gt": 2}},
            },
            "rerank": {
                "model": "mixedbread-ai/mxbai-rerank-base-v1",
                "query": "Test query",
                "num_documents_to_rerank": 100,
            },
            "limit": 5,
        },
        pipeline,
    )
    ids = [result["document"]["id"] for result in results]
    assert ids == [3, 3, 4, 4]
    await collection.archive()


###################################################
## Test RAG #######################################
###################################################


@pytest.mark.asyncio
async def test_can_rag():
    """Run ``Collection.rag`` with one ``CONTEXT`` variable fed by a vector search.

    A pipeline covering the ``body`` field is added to a fresh collection and
    five dummy documents are upserted. The request defines ``CONTEXT`` from a
    vector search and references it in the completion prompt. The test
    requires a non-empty first ``rag`` entry and non-empty ``CONTEXT``
    sources. The collection is archived at the end.
    """
    schema = {
        "body": {
            "splitter": {"model": "recursive_character"},
            "semantic_search": {
                "model": "intfloat/e5-small-v2",
                "parameters": {"prompt": "passage: "},
            },
        },
    }
    context_spec = {
        "vector_search": {
            "query": {
                "fields": {
                    "body": {
                        "query": "test",
                        "parameters": {"prompt": "query: "},
                    },
                },
            },
            "document": {"keys": ["id"]},
            "limit": 5,
        },
        "aggregate": {"join": "\n"},
    }
    completion_spec = {
        "model": "meta-llama/Meta-Llama-3-8B-Instruct",
        "prompt": "Some text with {CONTEXT}",
        "max_tokens": 10,
    }

    pipeline = korvus.Pipeline("1", schema)
    collection = korvus.Collection("test_p_c_cr")
    await collection.add_pipeline(pipeline)
    await collection.upsert_documents(generate_dummy_documents(5))

    answer = await collection.rag(
        {"CONTEXT": context_spec, "completion": completion_spec},
        pipeline,
    )
    assert len(answer["rag"][0]) > 0
    assert len(answer["sources"]["CONTEXT"]) > 0
    await collection.archive()


@pytest.mark.asyncio
async def test_can_rag_stream():
    """Run ``Collection.rag_stream`` and check that it streams non-empty chunks.

    A pipeline covering the ``body`` field is added to a fresh collection and
    five dummy documents are upserted. The request defines ``CONTEXT`` from a
    vector search and references it in the completion prompt. Every streamed
    chunk must be non-empty, and at least one chunk must arrive. The
    collection is archived at the end.
    """
    pipeline = korvus.Pipeline(
        "1",
        {
            "body": {
                "splitter": {"model": "recursive_character"},
                "semantic_search": {
                    "model": "intfloat/e5-small-v2",
                    "parameters": {"prompt": "passage: "},
                },
            },
        },
    )
    collection = korvus.Collection("test_p_c_crs")
    await collection.add_pipeline(pipeline)
    await collection.upsert_documents(generate_dummy_documents(5))
    results = await collection.rag_stream(
        {
            "CONTEXT": {
                "vector_search": {
                    "query": {
                        "fields": {
                            "body": {
                                "query": "test",
                                "parameters": {"prompt": "query: "},
                            },
                        },
                    },
                    "document": {"keys": ["id"]},
                    "limit": 5,
                },
                "aggregate": {"join": "\n"},
            },
            "completion": {
                "model": "meta-llama/Meta-Llama-3-8B-Instruct",
                "prompt": "Some text with {CONTEXT}",
                "max_tokens": 10,
            },
        },
        pipeline,
    )
    chunk_count = 0
    async for c in results.stream():
        assert len(c) > 0
        chunk_count += 1
    # The per-chunk assertion never runs on an empty stream, so require at
    # least one chunk explicitly; otherwise the test would pass vacuously.
    assert chunk_count > 0
    await collection.archive()


###################################################
## Test document related functions ################
###################################################


@pytest.mark.asyncio
async def test_upsert_and_get_documents():
    """Upsert ten documents and read them back three different ways.

    First an unfiltered ``get_documents`` call, then one using ``offset``,
    ``limit`` and a filter, and finally one that continues from the
    ``row_id`` of the last row of the previous page via ``last_row_id``.
    The collection is archived at the end.
    """
    collection = korvus.Collection("test_p_c_tuagd_2")
    await collection.upsert_documents(generate_dummy_documents(10))

    everything = await collection.get_documents()
    assert len(everything) == 10

    page_query = {"offset": 1, "limit": 2, "filter": {"id": {"$gt": 0}}}
    page = await collection.get_documents(page_query)
    assert len(page) == 2
    assert page[0]["document"]["id"] == 2

    continuation_query = {
        "filter": {
            "id": {"$lt": 7},
        },
        "last_row_id": page[-1]["row_id"],
    }
    remainder = await collection.get_documents(continuation_query)
    assert len(remainder) == 3
    assert remainder[0]["document"]["id"] == 4

    await collection.archive()


@pytest.mark.asyncio
async def test_delete_documents():
    """Delete by filter and check what remains.

    Three dummy documents are upserted, those matching
    ``{"id": {"$gte": 2}}`` are deleted, and the remaining documents must
    number two with the first one having id ``0``. The collection is archived
    at the end.
    """
    collection = korvus.Collection("test_p_c_tdd_1")
    await collection.upsert_documents(generate_dummy_documents(3))

    deletion_filter = {
        "id": {"$gte": 2},
    }
    await collection.delete_documents(deletion_filter)

    survivors = await collection.get_documents()
    assert len(survivors) == 2
    assert survivors[0]["document"]["id"] == 0
    await collection.archive()


@pytest.mark.asyncio
async def test_order_documents():
    """Fetch documents with ``order_by`` on ``id`` descending.

    Three dummy documents are upserted; with ``{"id": "desc"}`` ordering all
    three must be returned and the first must have id ``2``. The collection
    is archived at the end.
    """
    collection = korvus.Collection("test_p_c_tod_0")
    await collection.upsert_documents(generate_dummy_documents(3))

    ordering = {"order_by": {"id": "desc"}}
    ordered = await collection.get_documents(ordering)

    assert len(ordered) == 3
    assert ordered[0]["document"]["id"] == 2
    await collection.archive()


###################################################
## Transformer Pipeline Tests #####################
###################################################


@pytest.mark.asyncio
async def test_transformer_pipeline():
    """``TransformerPipeline.transform`` returns a non-empty result for one prompt."""
    generator = korvus.TransformerPipeline(
        "text-generation", "meta-llama/Meta-Llama-3-8B-Instruct"
    )
    prompts = ["AI is going to"]
    generated = await generator.transform(prompts, {"max_new_tokens": 5})
    assert len(generated) > 0


@pytest.mark.asyncio
async def test_transformer_pipeline_stream():
    """``TransformerPipeline.transform_stream`` yields at least one item."""
    generator = korvus.TransformerPipeline(
        "text-generation", "meta-llama/Meta-Llama-3-8B-Instruct"
    )
    stream = await generator.transform_stream("AI is going to", {"max_tokens": 5})
    # Drain the async iterator into a list so its length can be checked.
    received = [item async for item in stream]
    assert len(received) > 0


###################################################
## OpenSourceAI tests ###########################
###################################################


def test_open_source_ai_create():
    """The synchronous ``chat_completions_create`` returns at least one choice."""
    conversation = [
        {
            "role": "system",
            "content": "You are a friendly chatbot who always responds in the style of a pirate",
        },
        {
            "role": "user",
            "content": "How many helicopters can a human eat in one sitting?",
        },
    ]
    response = korvus.OpenSourceAI().chat_completions_create(
        "meta-llama/Meta-Llama-3-8B-Instruct",
        conversation,
        max_tokens=10,
        temperature=0.85,
    )
    assert len(response["choices"]) > 0


@pytest.mark.asyncio
async def test_open_source_ai_create_async():
    """The awaited ``chat_completions_create_async`` returns at least one choice."""
    conversation = [
        {
            "role": "system",
            "content": "You are a friendly chatbot who always responds in the style of a pirate",
        },
        {
            "role": "user",
            "content": "How many helicopters can a human eat in one sitting?",
        },
    ]
    response = await korvus.OpenSourceAI().chat_completions_create_async(
        "meta-llama/Meta-Llama-3-8B-Instruct",
        conversation,
        max_tokens=10,
        temperature=0.85,
    )
    assert len(response["choices"]) > 0


def test_open_source_ai_create_stream():
    """The synchronous ``chat_completions_create_stream`` yields at least one chunk."""
    conversation = [
        {
            "role": "system",
            "content": "You are a friendly chatbot who always responds in the style of a pirate",
        },
        {
            "role": "user",
            "content": "How many helicopters can a human eat in one sitting?",
        },
    ]
    stream = korvus.OpenSourceAI().chat_completions_create_stream(
        "meta-llama/Meta-Llama-3-8B-Instruct",
        conversation,
        temperature=0.85,
        n=3,
    )
    # Every chunk must expose a "choices" entry; collect them as they arrive.
    streamed_choices = [chunk["choices"] for chunk in stream]
    assert len(streamed_choices) > 0


@pytest.mark.asyncio
async def test_open_source_ai_create_stream_async():
    """The awaited ``chat_completions_create_stream_async`` yields at least one chunk."""
    conversation = [
        {
            "role": "system",
            "content": "You are a friendly chatbot who always responds in the style of a pirate",
        },
        {
            "role": "user",
            "content": "How many helicopters can a human eat in one sitting?",
        },
    ]
    stream = await korvus.OpenSourceAI().chat_completions_create_stream_async(
        "meta-llama/Meta-Llama-3-8B-Instruct",
        conversation,
        temperature=0.85,
        n=3,
    )
    # Every chunk must expose a "choices" entry; collect them as they arrive.
    streamed_choices = [chunk["choices"] async for chunk in stream]
    assert len(streamed_choices) > 0


###################################################
## Migration tests ################################
###################################################


@pytest.mark.asyncio
async def test_migrate():
    """Await ``korvus.migrate()``; the test passes if the call does not raise."""
    await korvus.migrate()


###################################################
## Test with multiprocessing ######################
###################################################


# def vector_search(collection_name, pipeline_name):
#     collection = korvus.Collection(collection_name)
#     pipeline = korvus.Pipeline(pipeline_name)
#     result = asyncio.run(
#         collection.query()
#         .vector_recall("Here is some query", pipeline)
#         .limit(10)
#         .fetch_all()
#     )
#     print(result)
#     return [0, 1, 2]


# @pytest.mark.asyncio
# async def test_multiprocessing():
#     collection_name = "test_p_p_tm_1"
#     pipeline_name = "test_p_c_tm_4"
#
#     model = korvus.Model()
#     splitter = korvus.Splitter()
#     pipeline = korvus.Pipeline(pipeline_name, model, splitter)
#
#     collection = korvus.Collection(collection_name)
#     await collection.upsert_documents(generate_dummy_documents(3))
#     await collection.add_pipeline(pipeline)
#
#     with Pool(5) as p:
#         results = p.starmap(
#             vector_search, [(collection_name, pipeline_name) for _ in range(5)]
#         )
#         for x in results:
#             print(x)
#             assert len(x) == 3
#
#     await collection.archive()
